from re import L
import gym

import os
import wandb
import time
from wandb.integration.sb3 import WandbCallback
from stable_baselines3.common.callbacks import CheckpointCallback
import hydra
from omegaconf import DictConfig, OmegaConf
from pathlib import Path
from stable_baselines3.common.evaluation import evaluate_policy
from stable_baselines3.common.callbacks import save_metrics_to_wandb
import pickle
import matplotlib.pyplot as plt
from copy import deepcopy as cp
from stable_baselines3.common.utils import safe_mean
from test_spreadsheet import write_metrics_to_spreadsheet

from utilities import *
from omegaconf import OmegaConf,open_dict
from collections import OrderedDict


def parse_arguments(args):
    """Apply run-wide preprocessing to the config and hand it back.

    Calls ``seed_everything(args.seed)`` and, when ``args.is_test`` is set,
    overwrites the training/evaluation budgets with small smoke-test values.

    Args:
        args: Run configuration; mutated in place.

    Returns:
        The same ``args`` object.
    """
    seed_everything(args.seed)

    if not args.is_test:
        return args

    # A combined budget of 0 is kept as-is; anything else is replaced by 2000.
    requested_steps = args.n_train_steps_combined
    args.n_train_steps_combined = requested_steps if requested_steps == 0 else 2000
    args.eval_freq = 2
    args.n_eval_episodes = 3
    args.save_freq = 1000

    return args


def init_wandb(args):
    """Start a Weights & Biases run for this config, if enabled.

    When ``args.use_wandb`` is falsy nothing happens. Otherwise a fresh run id
    and the current working directory are stored on ``args``, tensorboard is
    patched to log under ``runs/<run id>`` and ``wandb.init`` is called.

    Args:
        args: Run configuration; ``wandb_run_id`` and ``working_directory``
            are written to it when wandb is enabled.

    Returns:
        The same ``args`` object.
    """
    if not args.use_wandb:
        return args

    args.wandb_run_id = wandb.util.generate_id()

    # Store the cwd so specific runs are easier to locate on the cluster.
    current_dir = os.getcwd()
    args.working_directory = current_dir

    # Run name = env + last path component of the cwd + tag + unix timestamp.
    dir_leaf = current_dir.rsplit("/", 1)[-1]
    launched_at = int(time.time())
    run_name = f"{args.env}_{dir_leaf}_{args.wandb_tag}_{launched_at}"

    wandb.tensorboard.patch(root_logdir=f"runs/{args.wandb_run_id}")

    resolved_config = OmegaConf.to_container(args, resolve=True, throw_on_missing=True)
    wandb.init(
        id=args.wandb_run_id,
        name=run_name,
        project=args.wandb_project_name,
        entity=args.wandb_entity,
        config=resolved_config,
        sync_tensorboard=True,  # auto-upload sb3's tensorboard metrics
        monitor_gym=False,  # no automatic upload of gym videos
        save_code=True,  # optional
        tags=[args.wandb_tag],
        resume='allow',
    )

    return args


def setup_learning_adv(wandb_cb, adv_model, args):
    """Prepare the adversary model for manually driven training.

    Builds the callback list (the wandb callback when ``args.use_wandb`` is
    set, followed by a checkpoint callback writing ``adv_rl_model`` files to
    ``logs/``) and passes it to ``adv_model._setup_learn``.

    Args:
        wandb_cb: Callback appended first when wandb is enabled.
        adv_model: Adversary model exposing ``_setup_learn`` and ``env``.
        args: Run configuration (``save_freq``, ``n_steps_total_adv``,
            ``eval_freq``, ``n_eval_episodes``, ``use_wandb``).

    Returns:
        Tuple ``(adv_total_timesteps, adv_callback)`` as produced by
        ``_setup_learn``.
    """
    callback_list = []
    if args.use_wandb:
        callback_list.append(wandb_cb)

    callback_list.append(
        CheckpointCallback(
            save_freq=args.save_freq,
            save_path="logs/",
            name_prefix="adv_rl_model",
        )
    )

    total_timesteps, combined_callback = adv_model._setup_learn(
        total_timesteps=args.n_steps_total_adv,
        eval_env=adv_model.env,
        callback=callback_list,
        eval_freq=args.eval_freq,
        n_eval_episodes=args.n_eval_episodes,
        log_path=None,
        reset_num_timesteps=True,
        tb_log_name="",
    )

    return total_timesteps, combined_callback

def setup_learning_victim(wandb_cb, victim_model, args):
    """Prepare the victim model for manually driven training.

    Builds the callback list (the wandb callback when ``args.use_wandb`` is
    set, followed by a checkpoint callback writing ``victim_rl_model`` files
    to ``logs/``) and passes it to ``victim_model._setup_learn``.

    Args:
        wandb_cb: Callback appended first when wandb is enabled.
        victim_model: Victim model exposing ``_setup_learn`` and ``env``.
        args: Run configuration (``save_freq``, ``n_steps_total_vic``,
            ``eval_freq``, ``n_eval_episodes``, ``use_wandb``).

    Returns:
        Tuple ``(victim_total_timesteps, victim_callback, victim_iteration)``
        where the iteration counter starts at 0.
    """
    callback_list = []
    if args.use_wandb:
        callback_list.append(wandb_cb)

    callback_list.append(
        CheckpointCallback(
            save_freq=args.save_freq,
            save_path="logs/",
            name_prefix="victim_rl_model",
        )
    )

    iteration_counter = 0
    total_timesteps, combined_callback = victim_model._setup_learn(
        total_timesteps=args.n_steps_total_vic,
        eval_env=victim_model.env,
        callback=callback_list,
        eval_freq=args.eval_freq,
        n_eval_episodes=args.n_eval_episodes,
        log_path=None,
        reset_num_timesteps=True,
        tb_log_name="",
    )

    return total_timesteps, combined_callback, iteration_counter

def learn_step_adv(adv_model, meta_step, log_interval, adv_total_timesteps, adv_callback, args):
    """Train the adversary until it reaches this meta step's timestep budget.

    The budget is ``(meta_step + 1) * n_steps_vic_per_meta *
    n_adv_steps_per_vic_step``. Each loop iteration bumps
    ``adv_model.num_timesteps`` by one, collects a rollout and, once the
    model is past ``learning_starts``, runs gradient updates. Afterwards the
    meta step and the combined adversary + victim step count are logged.

    Args:
        adv_model: Adversary model (replay-buffer based interface).
        meta_step: Zero-based index of the current meta episode.
        log_interval: Forwarded to ``collect_rollouts``.
        adv_total_timesteps: Unused here; kept for a uniform call signature.
        adv_callback: Callback handed to ``collect_rollouts``.
        args: Run configuration.
    """
    step_budget = (meta_step + 1) * args.n_steps_vic_per_meta * args.n_adv_steps_per_vic_step

    while adv_model.num_timesteps < step_budget:
        adv_model.num_timesteps += 1
        rollout = adv_model.collect_rollouts(
            adv_model.env,
            train_freq=adv_model.train_freq,
            action_noise=adv_model.action_noise,
            callback=adv_callback,
            learning_starts=adv_model.learning_starts,
            replay_buffer=adv_model.replay_buffer,
            log_interval=log_interval,
        )

        # Only an explicit False from the rollout stops training.
        if rollout.continue_training is False:
            break

        # No gradient updates before the warm-up phase is over.
        if adv_model.num_timesteps <= 0 or adv_model.num_timesteps <= adv_model.learning_starts:
            continue

        # A negative gradient_steps means "as many updates as steps collected".
        if adv_model.gradient_steps >= 0:
            n_updates = adv_model.gradient_steps
        else:
            n_updates = rollout.episode_timesteps

        if n_updates > 0:
            adv_model.train(batch_size=adv_model.batch_size, gradient_steps=n_updates)

    victim_steps = adv_model.env.envs[0].victim_agent_model.num_timesteps
    adv_model.logger.record("meta_step", meta_step)
    adv_model.logger.record("combined_step", adv_model.num_timesteps + victim_steps)
    adv_model._dump_logs()


def learn_step_victim(victim_model, meta_step, log_interval, victim_iteration, victim_total_timesteps, victim_callback, args):
    """Train the victim until it reaches this meta step's timestep budget.

    Does nothing when ``args.n_steps_total_vic`` is 0. Otherwise rollouts are
    collected and ``victim_model.train()`` is called until
    ``victim_model.num_timesteps`` reaches
    ``(meta_step + 1) * args.n_steps_vic_per_meta``, after which timing and
    episode statistics are written to the model's logger.

    Args:
        victim_model: Victim model (rollout-buffer based interface).
        meta_step: Zero-based index of the current meta episode.
        log_interval: Unused here; kept for a uniform call signature.
        victim_iteration: Iteration count to start from. It is incremented
            locally only; the caller's value is not updated.
        victim_total_timesteps: Total budget used for the progress schedule.
        victim_callback: Callback handed to ``collect_rollouts``.
        args: Run configuration.
    """
    if args.n_steps_total_vic == 0:
        return

    step_budget = (meta_step + 1) * args.n_steps_vic_per_meta

    # Train the victim up to the budget of this meta step.
    while victim_model.num_timesteps < step_budget:
        keep_training = victim_model.collect_rollouts(
            victim_model.env,
            victim_callback,
            victim_model.rollout_buffer,
            n_rollout_steps=victim_model.n_steps,
        )

        # Only an explicit False from the rollout stops training.
        if keep_training is False:
            break

        victim_iteration += 1
        victim_model._update_current_progress_remaining(victim_model.num_timesteps, victim_total_timesteps)

        victim_model.train()

    # Training statistics.
    steps_this_run = victim_model.num_timesteps - victim_model._num_timesteps_at_start
    fps = int(steps_this_run / (time.time() - victim_model.start_time))

    log = victim_model.logger
    log.record("time/iterations", victim_iteration, exclude="tensorboard")

    episode_infos = victim_model.ep_info_buffer
    if len(episode_infos) > 0 and len(episode_infos[0]) > 0:
        for log_key, info_key in (("rollout/ep_rew_mean", "r"), ("rollout/ep_len_mean", "l")):
            log.record(log_key, safe_mean([info[info_key] for info in episode_infos]))

    adversary_steps = victim_model.env.envs[0].model_adv.num_timesteps
    log.record("time/fps", fps)
    log.record("meta_step", meta_step)
    log.record("combined_step", victim_model.num_timesteps + adversary_steps)
    log.record("time/time_elapsed", int(time.time() - victim_model.start_time), exclude="tensorboard")
    log.record("time/total_timesteps", victim_model.num_timesteps, exclude="tensorboard")
    log.dump(step=victim_model.num_timesteps)


def update_callbacks(adv_callback, victim_callback, args):
    """Pull the evaluation callbacks out of both callback lists.

    For each callback container, the first entry whose ``str()`` contains
    ``"Eval"`` is removed from ``.callbacks``. Both extracted callbacks then
    get ``args`` attached as an attribute.

    Args:
        adv_callback: Adversary callback container with a ``callbacks`` list.
        victim_callback: Victim callback container with a ``callbacks`` list.
        args: Run configuration to attach to the evaluation callbacks.

    Returns:
        Tuple ``(adv_eval_callback, victim_eval_callback, adv_callback,
        victim_callback)``; the two containers no longer hold their
        evaluation callback.
    """
    def _detach_eval(container):
        # First callback identified as an eval callback by its string form.
        match = next((cb for cb in container.callbacks if "Eval" in str(cb)), None)
        if match is not None:
            container.callbacks.remove(match)
        return match

    adv_eval = _detach_eval(adv_callback)
    victim_eval = _detach_eval(victim_callback)

    adv_eval.args = args
    victim_eval.args = args

    return adv_eval, victim_eval, adv_callback, victim_callback


def update_agents(victim_model, adv_model):
    """Cross-link the two models through their environments.

    Every environment of the victim gets ``model_adv`` set to the adversary,
    and the adversary's first environment gets ``victim_agent_model`` set to
    the victim.

    Returns:
        Tuple ``(victim_model, adv_model)``, the same objects that came in.
    """
    for victim_env in victim_model.env.envs:
        victim_env.model_adv = adv_model

    first_adv_env = adv_model.env.envs[0]
    first_adv_env.victim_agent_model = victim_model

    return victim_model, adv_model


def train(victim_model, adv_model, args):
    """Alternate victim and adversary training over a number of meta episodes.

    The combined step budget ``args.n_train_steps_combined`` is split between
    the two agents according to ``args.n_adv_steps_per_vic_step`` (the special
    value ``"all"`` trains the adversary only). Each meta episode trains the
    victim (if it has a co-training budget), evaluates it, then trains and
    evaluates the adversary. When ``args.evaluate_exploitability`` is set, the
    victim is afterwards trained alone for an additional fixed number of
    steps.

    Args:
        victim_model: Victim model.
        adv_model: Adversary model.
        args: Run configuration; ``n_meta_episodes``, ``n_steps_total_vic``,
            ``n_steps_total_adv`` (and, in adversary-only mode,
            ``n_adv_steps_per_vic_step``) are written to it.
    """
    # general setup
    log_interval = args.save_freq  # forwarded to the learn steps; should not have any effect
    wandb_cb = WandbCallback() if args.use_wandb else None

    if args.n_adv_steps_per_vic_step == "all":
        print("only training adversary agent")
        args.n_meta_episodes = int(np.ceil(args.n_train_steps_combined/args.n_steps_vic_per_meta))
        args.n_steps_total_vic = 0
        args.n_adv_steps_per_vic_step = 1
        args.n_steps_total_adv = args.n_train_steps_combined
    else:
        args.n_meta_episodes = int(np.ceil(args.n_train_steps_combined / (args.n_steps_vic_per_meta * (1 + args.n_adv_steps_per_vic_step))))
        args.n_steps_total_vic = int(np.ceil(args.n_meta_episodes * args.n_steps_vic_per_meta))
        args.n_steps_total_adv = int(np.ceil(args.n_meta_episodes * args.n_steps_vic_per_meta * args.n_adv_steps_per_vic_step))

    # Victim budget of the co-training phase alone. It has to be captured
    # before the exploitability budget is added: the decision whether the
    # victim trains during the meta loop must not depend on that extra budget.
    n_steps_cotrain_vic = args.n_steps_total_vic

    n_steps_exploitability = 300000
    if args.evaluate_exploitability:
        args.n_steps_total_vic += n_steps_exploitability

    adv_total_timesteps, adv_callback = setup_learning_adv(wandb_cb, adv_model, args)
    victim_total_timesteps, victim_callback, victim_iteration = setup_learning_victim(wandb_cb, victim_model, args)

    adv_callback.on_training_start(locals(), globals())
    victim_callback.on_training_start(locals(), globals())

    adv_eval_callback, victim_eval_callback, adv_callback, victim_callback = update_callbacks(adv_callback, victim_callback, args)

    # Baseline evaluation before any training happens.
    victim_eval_callback.trigger_eval_manually()
    adv_eval_callback.trigger_eval_manually()

    for meta_step in range(args.n_meta_episodes):

        print(f"new meta ep: meta_ep:{meta_step} victim_step:{victim_model.num_timesteps} adv_step:{adv_model.num_timesteps}")

        # Previously this compared the total budget against the exploitability
        # budget, which silently skipped victim training whenever
        # exploitability evaluation was off and the victim budget was <= 300000.
        if n_steps_cotrain_vic > 0:
            learn_step_victim(victim_model, meta_step, log_interval, victim_iteration, victim_total_timesteps, victim_callback, args)
        victim_eval_callback.on_step()

        learn_step_adv(adv_model, meta_step, log_interval, adv_total_timesteps, adv_callback, args)
        adv_eval_callback.on_step()

    if args.evaluate_exploitability:
        # learn_step_victim trains until num_timesteps reaches
        # (meta_step + 1) * n_steps_vic_per_meta, so the meta-step index has to
        # continue from the steps the victim has already taken; restarting at
        # 0 would make the first iterations no-ops after co-training.
        first_meta_step = int(victim_model.num_timesteps // args.n_steps_vic_per_meta)
        n_extra_meta_steps = int(np.ceil(n_steps_exploitability/args.n_steps_vic_per_meta))
        for i in range(n_extra_meta_steps):
            learn_step_victim(victim_model, first_meta_step + i, log_interval, victim_iteration, victim_total_timesteps, victim_callback, args)
            victim_eval_callback.on_step()

    adv_callback.on_training_end()
    victim_callback.on_training_end()

def evaluate(args, noise=0, render=False):
    """Evaluate the adversary policy in its environment.

    A deep copy of ``args`` with ``victim_noise_sigma`` set to ``noise`` is
    used to build the models. Without rendering, the metrics are written to
    the spreadsheet; with rendering, the frames recorded by the environment
    are passed to ``process_recorded_frames``.

    Args:
        args: Run configuration.
        noise: Value assigned to ``victim_noise_sigma`` for this evaluation.
        render: When true, run 5 episodes with the environment's ``render``
            flag set and process the recorded frames instead of writing
            metrics.

    Returns:
        The metrics returned by ``evaluate_policy``.
    """
    print(f"running evaluation with noise {noise}")

    args_for_exp = cp(args)
    args_for_exp.victim_noise_sigma = noise

    # No trained adversary checkpoint is loaded for MNP / no-attack runs or
    # when this run did not train.
    if args.mnp_attack or args.no_attack or not args.train:
        adv_checkpoint = None
    else:
        adv_checkpoint = "latest"

    victim_model, model_adv = get_setup(args_for_exp,
                                        timestep_victim=None,
                                        timestep_adv=adv_checkpoint)

    adv_env = model_adv.env.envs[0]

    # potentially put environment into render mode
    if render:
        adv_env.render = True

    # possible change to MNP attack mode
    adv_env.run_mnp_attack = bool(args_for_exp.mnp_attack)

    # possible change to effect no attack
    if args.no_attack:
        adv_env.action_scale = 0

    n_episodes = 5 if render else args_for_exp.n_eval_episodes

    metrics = evaluate_policy(model_adv, adv_env, n_eval_episodes=n_episodes, deterministic=True, render=False, return_episode_rewards=True)

    if not render:
        write_metrics_to_spreadsheet(metrics, args_for_exp)
        return metrics

    # Playback speed is derived from the victim environment's dt, with
    # per-environment overrides below.
    if "Pendulum" in args.env:
        fps = int(1 / adv_env.victim_env.dt * 2)
    else:
        fps = int(3 / adv_env.victim_env.dt)

    env_description = str(adv_env)
    if "HalfCheetah" in env_description:
        fps = int(fps/10)

    if "Hopper" in str(adv_env):
        fps = 50

    combined_steps = model_adv.num_timesteps + victim_model.num_timesteps
    mean_reward = np.mean(metrics['rewards']).round(2)
    process_recorded_frames(adv_env.frames_true,
                            adv_env.frames_seen,
                            fps=fps,
                            args=args,
                            log_title=f"combined_step_{combined_steps}_advreward_{mean_reward}")
    return metrics


def evaluate_detection(args):
    """Evaluate the victim policy with the world-model shut-off switched on.

    Builds the models, sets ``use_worldmodel_to_shutoff`` on the victim's
    first environment, optionally switches to the MNP attack or zeroes the
    adversary's action scale, runs ``evaluate_policy`` on the victim and
    writes the metrics to the spreadsheet.

    Args:
        args: Run configuration.
    """
    print("running evaluate detection")

    # No trained adversary checkpoint is loaded for MNP / no-attack runs.
    adv_checkpoint = None if (args.mnp_attack or args.no_attack) else "latest"
    victim_model, model_adv = get_setup(args,
                                        timestep_victim=None,
                                        timestep_adv=adv_checkpoint)

    victim_env = victim_model.env.envs[0]

    # turn worldmodel on
    victim_env.use_worldmodel_to_shutoff = True

    # potentially switch to NP attack
    if args.mnp_attack:
        victim_env.run_mnp_attack = True

    # potentially switch off adversary
    if args.no_attack:
        victim_env.model_adv.env.envs[0].action_scale = 0

    metrics = evaluate_policy(
        victim_model,
        victim_env,
        n_eval_episodes=args.n_eval_episodes,
        deterministic=True,
        render=False,
        return_episode_rewards=True,
    )

    write_metrics_to_spreadsheet(metrics, args)


def run_evaluation(args):
    """Run ``evaluate`` for every noise level selected for this experiment.

    Also backfills ``args.experiment_id_new`` (the first two ``_``-separated
    parts of ``args.experiment_id``) when the key is absent.

    Args:
        args: Run configuration.
    """
    # this is temporary
    if "experiment_id_new" not in args.keys():
        id_parts = args.experiment_id.split("_")
        with open_dict(args):
            args["experiment_id_new"] = f"{id_parts[0]}_{id_parts[1]}"

    # if necessary, evaluate for all noise levels and save
    single_noise_only = (
        args.experiment_id.split("_")[0] == "cotrain"
        or args.mnp_attack
        or args.no_attack
        or args.action_scale == -1
    )
    if single_noise_only:
        noise_levels = [args.victim_noise_sigma]
    else:
        # TODO make sure that eval is run correctly for noise
        # TODO figure out why the token expired again
        noise_levels = [args.victim_noise_sigma]  # re add action_scale here for noise evaluation

    for noise_level in noise_levels:
        evaluate(args, noise=noise_level)

def run_post_evaluation(args):
    """Re-run the evaluation as a post-processing step.

    Defaults ``args.adv_config`` to 0 when the key is absent, appends
    ``"_postprocess"`` to ``args.wandb_tag`` and calls ``run_evaluation``.

    Args:
        args: Run configuration; mutated in place.
    """
    has_adv_config = "adv_config" in args.keys()
    if not has_adv_config:
        with open_dict(args):
            args.adv_config = 0

    args.wandb_tag = f"{args.wandb_tag}_postprocess"

    run_evaluation(args)

def run_post_render(args):
    """Render an evaluation into the already existing wandb run.

    Defaults ``args.adv_config`` to 0 when the key is absent, resumes the run
    identified by ``args.wandb_run_id`` (``resume="must"``), runs a rendered
    evaluation and finishes the run.

    Args:
        args: Run configuration; mutated in place.
    """
    has_adv_config = "adv_config" in args.keys()
    if not has_adv_config:
        with open_dict(args):
            args.adv_config = 0

    wandb.init(
        id=args.wandb_run_id,
        project=args.wandb_project_name,
        resume="must",
    )
    evaluate(args, render=True)
    wandb.finish()

def _save_args_used(args):
    """Pickle the config of this run to ``args_used.pickle`` in the cwd."""
    # The context manager guarantees the handle is flushed and closed.
    with open("args_used.pickle", "wb") as handle:
        pickle.dump(args, handle)


def _close_wandb():
    """Finish the wandb run and undo the tensorboard patching."""
    wandb.finish()
    wandb.tensorboard.unpatch()
    wandb.tensorboard.reset_state()


@hydra.main(version_base=None, config_path="conf", config_name="config")
def main(args : DictConfig):
    """Hydra entry point: train, evaluate and/or render according to the config.

    The three phases are independent and run in this order when their flag
    (``args.train``, ``args.evaluate``, ``args.render``) is set. The train and
    render phases each initialise their own wandb run and store the config
    used in ``args_used.pickle``.

    Args:
        args: Hydra-composed run configuration.
    """
    args = parse_arguments(args)

    victim_model, adv_model = get_setup(args)

    if args.train:
        args = init_wandb(args)
        _save_args_used(args)

        train(victim_model, adv_model, args)
        _close_wandb()

    if args.evaluate:
        run_evaluation(args)

    if args.render:
        args = init_wandb(args)
        _save_args_used(args)

        evaluate(args, render=True)
        _close_wandb()

    # Debugging aids, kept for reference:
    # adv_model.env.envs[0].save_transitions_to_file()

    # import pdb
    # pdb.set_trace()

    # import numpy as np
    # arrts = adv_model.env.envs[0].tolerance_statistics
    # values, counts = np.unique(arrts, return_counts=True)
    # print(counts)

    print("DONE DONE DONE", flush=True)

if __name__ == "__main__":
    # Report the wall-clock duration of the whole run.
    started_at = time.time()
    main()
    elapsed_seconds = time.time() - started_at
    print(f"total_time: {elapsed_seconds}")
